import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
from exllama_lib.model import ExLlama, ExLlamaCache, ExLlamaConfig
from opencompass.models.base import BaseModel
from opencompass.registry import MODELS
from opencompass.utils.logging import get_logger
from opencompass.utils.prompt import PromptList
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    LlamaForCausalLM,
    LlamaTokenizer,
    PretrainedConfig,
    PreTrainedModel,
)
from transformers.modeling_outputs import CausalLMOutputWithPast

# Type alias for prompts accepted by the model wrappers: either an
# OpenCompass ``PromptList`` or a plain string.
PromptType = Union[PromptList, str]

# Set at import time, before any tokenizer is created.
# NOTE(review): presumably silences the HF tokenizers fork/parallelism
# warning — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Instruction/response delimiters and the prompt template built from them.
# ``prompt_input`` expects a single ``{instruction}`` format field.
tok_ins = "\n\n### Instruction:\n"
tok_res = "\n\n### Response:\n"
prompt_input = tok_ins + "{instruction}" + tok_res


class ExllamaHF(PreTrainedModel):
    """Expose an ExLlama model through the ``PreTrainedModel`` interface.

    The wrapper keeps an ``ExLlamaCache`` together with the token sequence
    that filled it, so that a forward call which merely extends the previous
    sequence by one token only has to process that newest token.

    Note:
        Only the first row of ``input_ids`` is evaluated by ``__call__``;
        batched inputs are not supported by this wrapper.
    """

    def __init__(self, config: ExLlamaConfig):
        # The HF base class requires a config object; an empty placeholder is
        # passed because the real settings live in the ExLlama config.
        super().__init__(PretrainedConfig())
        self.ex_config = config
        self.ex_model = ExLlama(self.ex_config)
        self.generation_config = GenerationConfig()
        self.lora = None

        # Cache and the token sequence it currently holds.
        self.ex_cache = ExLlamaCache(self.ex_model)
        self.past_seq = None

        # Second cache/sequence pair. ``__call__`` never selects it, but the
        # attributes are kept so external code reading them keeps working.
        self.ex_cache_negative = ExLlamaCache(self.ex_model)
        self.past_seq_negative = None

    def _validate_model_class(self):
        """Disable the HF generation-time model class validation."""
        pass

    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
        """Disable the HF generation-time keyword validation."""
        pass

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Forward ``input_ids`` and all extra keywords unchanged."""
        return {"input_ids": input_ids, **kwargs}

    @property
    def device(self) -> torch.device:
        """Device reported to callers; hard-coded to device index 0."""
        return torch.device(0)

    def __call__(self, *args, **kwargs):
        """Run a forward pass over the first sequence of ``input_ids``.

        Args:
            *args: Optionally ``input_ids`` as the first positional argument.
            **kwargs: Recognised keys are ``input_ids`` (used when no
                positional argument is given), ``use_cache`` (default True),
                ``labels``, ``ignore_token`` (label id excluded from the
                loss) and ``mask_length`` (per-sequence number of leading
                tokens excluded from the loss). ``past_key_values`` and any
                other keys are accepted and ignored.

        Returns:
            CausalLMOutputWithPast: ``logits`` of the forward call;
            ``past_key_values`` holding the evaluated token list (or None
            when ``use_cache`` is false); ``loss`` as a numpy array with the
            mean token loss per sequence when ``labels`` is given, else None.
        """
        use_cache = kwargs.get("use_cache", True)
        labels = kwargs.get("labels", None)
        ignore_token = kwargs.get("ignore_token", None)
        mask_length = kwargs.get("mask_length", None)

        input_ids = args[0] if len(args) > 0 else kwargs["input_ids"]
        ex_cache = self.ex_cache

        seq = input_ids[0].tolist()
        seq_tensor = torch.tensor(seq)

        if labels is None:
            # Generation path: only the logits of the last token are needed.
            # Re-fill the cache with the prefix unless it already holds
            # exactly that prefix from the previous call.
            cache_is_stale = self.past_seq is None or not torch.equal(
                self.past_seq, seq_tensor[:-1]
            )
            if cache_is_stale:
                ex_cache.current_seq_len = 0
                self.ex_model.forward(
                    torch.tensor([seq[:-1]], dtype=torch.long),
                    ex_cache,
                    preprocess_only=True,
                    lora=self.lora,
                )

            logits = self.ex_model.forward(
                torch.tensor([seq[-1:]], dtype=torch.long),
                ex_cache,
                lora=self.lora,
            ).to(input_ids.device)
        else:
            # Scoring path: logits for every position are required, so the
            # whole sequence is evaluated from an empty cache.
            ex_cache.current_seq_len = 0
            logits = self.ex_model.forward(
                torch.tensor([seq], dtype=torch.long),
                ex_cache,
                last_id_only=False,
                lora=self.lora,
            ).to(input_ids.device)

        # Remember which tokens the cache now contains.
        self.past_seq = seq_tensor

        loss = None
        if labels is not None:
            loss = self._mean_token_loss(
                logits, labels, ignore_token, mask_length
            )

        return CausalLMOutputWithPast(
            logits=logits,
            past_key_values=seq if use_cache else None,
            loss=loss,
        )

    @staticmethod
    def _mean_token_loss(logits, labels, ignore_token, mask_length):
        """Return the per-sequence mean next-token cross-entropy (numpy)."""
        # ``ignore_token`` may legitimately be 0, so it must be compared with
        # None rather than tested for truthiness. -100 is the default
        # ``ignore_index`` of ``CrossEntropyLoss``.
        ignore_index = -100 if ignore_token is None else ignore_token

        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss(
            reduction="none", ignore_index=ignore_index
        )
        loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        ).view(shift_labels.size())

        if mask_length is not None:
            # Keep only shifted positions >= mask_length[i] - 1 of each row.
            first_scored = (
                torch.as_tensor(mask_length, device=loss.device).view(-1, 1)
                - 1
            )
            positions = torch.arange(
                shift_labels.size(-1), device=loss.device
            )
            loss = loss * (positions >= first_scored).to(loss.dtype)

        # Normalise by the number of non-ignored label tokens, minus the
        # masked-out prefix when one is given.
        lens = (labels != ignore_index).sum(-1).cpu().numpy()
        if mask_length is not None:
            # In-place on purpose: a mask_length whose size does not match
            # the batch raises instead of silently broadcasting.
            lens -= np.array(mask_length)

        return loss.sum(-1).cpu().detach().numpy() / lens

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[
            Union[str, os.PathLike]
        ],
        *model_args,
        **kwargs,
    ):
        """Build the wrapper from a local model directory.

        The directory must contain ``config.json`` and at least one weight
        file ending in ``.safetensors``, ``.pt`` or ``.bin`` (searched in
        that order of preference).

        Args:
            pretrained_model_name_or_path: Path of the model directory.
            *model_args: Not supported; must be empty.
            **kwargs: Not supported; must be empty.

        Returns:
            An instance of ``cls`` wrapping the loaded ExLlama model.

        Raises:
            AssertionError: If extra arguments are passed or no weight file
                is found in the directory.
        """
        assert (
            len(model_args) == 0 and len(kwargs) == 0
        ), "extra args is currently not supported"
        # Path() accepts both str and any other os.PathLike.
        model_dir = Path(pretrained_model_name_or_path)

        config = ExLlamaConfig(model_dir / "config.json")

        weight_path = None
        for ext in [".safetensors", ".pt", ".bin"]:
            # Sorted so the choice is deterministic when several files match.
            found = sorted(model_dir.glob(f"*{ext}"))
            if len(found) > 0:
                weight_path = found[-1]
                break
        assert (
            weight_path is not None
        ), f'could not find weight in "{model_dir}"'

        config.model_path = str(weight_path)
        config.max_seq_len = 2048
        config.compress_pos_emb = 1

        if torch.version.hip:
            # ROCm builds: turn off the half2 kernels.
            config.rmsnorm_no_half2 = True
            config.rope_no_half2 = True
            config.matmul_no_half2 = True
            config.silu_no_half2 = True

        # This slows down a bit but aligns better with autogptq generation.
        # TODO: Should give user choice to tune the exllama config
        # config.fused_attn = False
        # config.fused_mlp_thd = 0

        return cls(config)


@MODELS.register_module()
class HuggingFace(BaseModel):
    """Model wrapper around HuggingFace general models.

    Args:
        path (str): The name or path to HuggingFace's model.
        hf_cache_dir: Set the cache dir to HF model cache dir. If None, it will
            use the env variable HF_MODEL_HUB. Defaults to None.
        max_seq_len (int): The maximum length of the input sequence. Defaults
            to 2048.
        tokenizer_path (str): The path to the tokenizer. Defaults to None.
        tokenizer_kwargs (dict): Keyword arguments for the tokenizer.
            Defaults to {}.
        peft_path (str, optional): The name or path to the HuggingFace's PEFT
            model. If None, the original model will not be converted to PEFT.
            Defaults to None.
        tokenizer_only (bool): If True, only the tokenizer will be initialized.
            Defaults to False.
        model_kwargs (dict): Keyword arguments for the model, used in loader.
            Defaults to dict(device_map='auto').
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
        extract_pred_after_decode (bool): Whether to extract the prediction
            string from the decoded output string, instead of extract the
            prediction tokens before decoding. Defaults to False.
        batch_padding (bool): If False, inference will be performed in for-loop
            without batch padding.

    Note:
        About ``extract_pred_after_decode``: Commonly, we should extract the
        prediction tokens before decoding. But for some tokenizers using
        ``sentencepiece``, like LLaMA,  this behavior may change the number of
        whitespaces, which is harmful for Python programming tasks.
    """

    def __init__(
        self,
        path: str,
        hf_cache_dir: Optional[str] = None,
        max_seq_len: int = 2048,
        tokenizer_path: Optional[str] = None,
        tokenizer_kwargs: dict = dict(),
        peft_path: Optional[str] = None,
        tokenizer_only: bool = False,
        model_kwargs: dict = dict(device_map="auto"),
        meta_template: Optional[Dict] = None,
        extract_pred_after_decode: bool = False,
        batch_padding: bool = False,
    ):
        super().__init__(
            path=path,
            max_seq_len=max_seq_len,
            tokenizer_only=tokenizer_only,
            meta_template=meta_template,
        )
        from opencompass.utils.fileio import patch_hf_auto_model

        # The default dicts in the signature are shared between calls; work
        # on private copies so nothing done below can leak into them (or
        # into a dict owned by the caller).
        tokenizer_kwargs = dict(tokenizer_kwargs) if tokenizer_kwargs else {}
        model_kwargs = dict(model_kwargs) if model_kwargs else {}

        if hf_cache_dir is None:
            hf_cache_dir = os.getenv("HF_MODEL_HUB", None)
        patch_hf_auto_model(hf_cache_dir)
        self.logger = get_logger()
        self.config = AutoConfig.from_pretrained(path, **model_kwargs)
        self._load_tokenizer(
            path=path,
            tokenizer_path=tokenizer_path,
            tokenizer_kwargs=tokenizer_kwargs,
        )
        self.batch_padding = batch_padding
        self.extract_pred_after_decode = extract_pred_after_decode
        if not tokenizer_only:
            self._load_model(
                path=path,
                model_kwargs=model_kwargs,
                peft_path=peft_path,
            )

    def _load_tokenizer(
        self,
        path: str,
        tokenizer_path: Optional[str],
        tokenizer_kwargs: dict,
    ):
        """Create ``self.tokenizer`` and make sure it has a pad token.

        Args:
            path (str): Model name or path, used when ``tokenizer_path`` is
                empty.
            tokenizer_path (str, optional): Dedicated tokenizer name or path.
            tokenizer_kwargs (dict): Extra keywords for ``from_pretrained``.
        """
        tokenizer_source = tokenizer_path if tokenizer_path else path
        # Log the location the tokenizer is really loaded from.
        self.logger.warning(f"tokenizer path: {tokenizer_source}")

        if "llama" in self.config.model_type.lower():
            self.tokenizer = LlamaTokenizer.from_pretrained(
                tokenizer_source,
                **tokenizer_kwargs,
            )
            # LLaMA models get a dedicated pad token when none is defined.
            if self.tokenizer.pad_token_id is None:
                self.tokenizer.add_special_tokens(dict(pad_token="<pad>"))
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_source,
                **tokenizer_kwargs,
            )

        if self.tokenizer.pad_token_id is None:
            self.logger.warning(
                "pad_token_id is not set for the tokenizer. "
                "Using eos_token_id as pad_token_id."
            )
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # A patch for llama when batch_padding = True
        if "decapoda-research/llama" in path or (
            tokenizer_path
            and "decapoda-research/llama" in tokenizer_path
        ):
            self.logger.warning(
                "We set new pad_token_id for LLaMA model"
            )
            # keep consistent with official LLaMA repo
            # https://github.com/google/sentencepiece/blob/master/python/sentencepiece_python_module_example.ipynb  # noqa
            self.tokenizer.bos_token = "<s>"
            self.tokenizer.eos_token = "</s>"
            self.tokenizer.pad_token_id = 0

    def _load_model(
        self,
        path: str,
        model_kwargs: dict,
        peft_path: Optional[str] = None,
    ):
        """Load ``self.model`` with ``AutoModel`` and switch it to eval mode.

        Args:
            path (str): Model name or path.
            model_kwargs (dict): Keywords for ``from_pretrained``;
                ``torch_dtype`` defaults to ``torch.float16`` when absent.
            peft_path (str, optional): PEFT adapter to wrap the model with.
        """
        from transformers import AutoModel

        # Copy first: setdefault() must not modify the caller's dict.
        model_kwargs = dict(model_kwargs)
        model_kwargs.setdefault("torch_dtype", torch.float16)
        self.model = AutoModel.from_pretrained(path, **model_kwargs)
        if peft_path is not None:
            from peft import PeftModel

            self.model = PeftModel.from_pretrained(
                self.model, peft_path, is_trainable=False
            )
        self.model.eval()

        # A patch for llama when batch_padding = True
        if "decapoda-research/llama" in path:
            self.model.config.bos_token_id = 1
            self.model.config.eos_token_id = 2
            self.model.config.pad_token_id = self.tokenizer.pad_token_id

    def generate(
        self, inputs: List[str], max_out_len: int
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str]): A list of strings.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """
        if self.batch_padding and len(inputs) > 1:
            return self._batch_generate(
                inputs=inputs, max_out_len=max_out_len
            )

        # One prompt at a time; extend() avoids the quadratic list copying
        # of ``sum(lists, [])``.
        results: List[str] = []
        for input_ in inputs:
            results.extend(
                self._single_generate(
                    inputs=[input_], max_out_len=max_out_len
                )
            )
        return results

    def _decode_generation(
        self, outputs, num_prompt_tokens: int, inputs: List[str]
    ) -> List[str]:
        """Decode generated ids and strip the prompt from every result."""
        if self.extract_pred_after_decode:
            # Decode prompt + completion, then cut the prompt off by its
            # character length.
            decodeds = self.tokenizer.batch_decode(
                outputs, skip_special_tokens=True
            )
            return [
                text[len(prompt):]
                for text, prompt in zip(decodeds, inputs)
            ]

        # Cut the prompt off at token level before decoding.
        return self.tokenizer.batch_decode(
            outputs[:, num_prompt_tokens:], skip_special_tokens=True
        )

    def _batch_generate(
        self, inputs: List[str], max_out_len: int
    ) -> List[str]:
        """Support for batch prompts inference.

        Args:
            inputs (List[str]): A list of strings.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """
        # step-1: tokenize the input with batch_encode_plus
        tokens = self.tokenizer.batch_encode_plus(
            inputs,
            padding=True,
            truncation=True,
            max_length=self.max_seq_len - max_out_len,
        )
        tokens = {
            k: torch.tensor(
                np.array(tokens[k]), device=self.model.device
            )
            for k in tokens
            if k in ["input_ids", "attention_mask"]
        }
        # step-2: conduct model forward to generate output
        outputs = self.model.generate(
            **tokens,
            max_new_tokens=max_out_len,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
        )
        # step-3: decode and drop the prompt part
        return self._decode_generation(
            outputs, tokens["input_ids"].shape[1], inputs
        )

    def _single_generate(
        self, inputs: List[str], max_out_len: int
    ) -> List[str]:
        """Support for single prompt inference.

        Args:
            inputs (List[str]): A list of strings.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """
        input_ids = self.tokenizer(
            inputs,
            truncation=True,
            max_length=self.max_seq_len - max_out_len,
        )["input_ids"]
        input_ids = torch.tensor(input_ids, device=self.model.device)
        outputs = self.model.generate(
            input_ids=input_ids,
            max_new_tokens=max_out_len,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
        )
        return self._decode_generation(
            outputs, input_ids.shape[1], inputs
        )

    def get_logits(self, inputs: List[str]):
        """Run the model on ``inputs`` and return its first output.

        Args:
            inputs (List[str]): A list of strings. Without batch padding the
                tokenized inputs are stacked unpadded.

        Returns:
            tuple: ``(outputs[0], {"tokens": tokens})`` where ``tokens`` is
            the dict of tensors that was fed to the model.
        """
        if self.batch_padding and len(inputs) > 1:
            # batch inference
            tokens = self.tokenizer(
                inputs,
                padding=True,
                truncation=True,
                max_length=self.max_seq_len,
            )

            tokens = {
                k: torch.tensor(
                    np.array(tokens[k]), device=self.model.device
                )
                for k in tokens
                if k in ["input_ids", "attention_mask"]
            }
            # Inference only: do not record an autograd graph.
            with torch.no_grad():
                outputs = self.model(**tokens)

        else:
            input_ids = self.tokenizer(
                inputs,
                padding=False,
                truncation=True,
                max_length=self.max_seq_len,
            )["input_ids"]
            input_ids = torch.tensor(
                input_ids, device=self.model.device
            )
            tokens = {"input_ids": input_ids}

            # Inference only: do not record an autograd graph.
            with torch.no_grad():
                outputs = self.model(input_ids)
        return outputs[0], {"tokens": tokens}

    def get_ppl(
        self, inputs: List[str], mask_length: Optional[List[int]] = None
    ) -> List[float]:
        """Get perplexity scores given a list of inputs.

        Args:
            inputs (List[str]): A list of strings.
            mask_length (Optional[List[int]]): A list of mask lengths. If
                provided, the perplexity scores will be calculated with the
                first mask_length[i] tokens masked out. It's okay to skip
                its implementation if advanced features in PPLInferencer is
                not needed.

        Returns:
            List[float]: A list of perplexity scores (a numpy array with one
            entry per input).
        """
        if self.batch_padding and len(inputs) > 1:
            assert self.tokenizer.pad_token
            return self._get_ppl(inputs, mask_length=mask_length)

        if not inputs:
            return np.array([])

        # Score one input at a time, handing each call only the mask length
        # that belongs to that input.
        scores = [
            self._get_ppl(
                inputs=[text],
                mask_length=(
                    None if mask_length is None else [mask_length[idx]]
                ),
            )
            for idx, text in enumerate(inputs)
        ]
        return np.concatenate(scores)

    def _get_ppl(
        self, inputs: List[str], mask_length: Optional[List[int]] = None
    ) -> List[float]:
        """Get perplexity scores given a list of inputs.

        Args:
            inputs (List[str]): A list of strings.
            mask_length (Optional[List[int]]): A list of mask lengths. If
                provided, the perplexity scores will be calculated with the
                first mask_length[i] tokens masked out. It's okay to skip
                its implementation if advanced features in PPLInferencer is
                not needed.

        Returns:
            List[float]: A list of perplexity scores (a numpy array with one
            entry per input).
        """
        outputs, extra = self.get_logits(inputs)
        input_ids = extra["tokens"]["input_ids"]

        shift_logits = outputs[..., :-1, :].contiguous()
        shift_labels = input_ids[..., 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss(
            reduction="none", ignore_index=self.tokenizer.pad_token_id
        )
        loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        ).view(shift_labels.size())

        if mask_length is not None:
            # Keep only shifted positions >= mask_length[i] - 1 of each row.
            first_scored = (
                torch.as_tensor(mask_length, device=loss.device).view(-1, 1)
                - 1
            )
            positions = torch.arange(
                shift_labels.size(-1), device=loss.device
            )
            loss = loss * (positions >= first_scored).to(loss.dtype)

        # Normalise by the number of non-pad tokens, minus the masked prefix.
        lens = (
            (input_ids != self.tokenizer.pad_token_id)
            .sum(-1)
            .cpu()
            .numpy()
        )
        if mask_length is not None:
            # In-place on purpose: a size mismatch raises instead of
            # silently broadcasting.
            lens -= np.array(mask_length)
        ce_loss = loss.sum(-1).cpu().detach().numpy() / lens
        return ce_loss

    def get_token_len(self, prompt: str) -> int:
        """Get lengths of the tokenized strings.

        Args:
            prompt (str): Input string.

        Returns:
            int: Length of the input tokens
        """
        return len(self.tokenizer.encode(prompt))


@MODELS.register_module()
class HuggingFaceCausalLM(HuggingFace):
    """Model wrapper around HuggingFace CausalLM.

    Args:
        path (str): The name or path to HuggingFace's model.
        hf_cache_dir: Set the cache dir to HF model cache dir. If None, it will
            use the env variable HF_MODEL_HUB. Defaults to None.
        max_seq_len (int): The maximum length of the input sequence. Defaults
            to 2048.
        tokenizer_path (str): The path to the tokenizer. Defaults to None.
        tokenizer_kwargs (dict): Keyword arguments for the tokenizer.
            Defaults to {}.
        peft_path (str, optional): The name or path to the HuggingFace's PEFT
            model. If None, the original model will not be converted to PEFT.
            Defaults to None.
        tokenizer_only (bool): If True, only the tokenizer will be initialized.
            Defaults to False.
        model_kwargs (dict): Keyword arguments for the model, used in loader.
            Defaults to dict(device_map='auto').
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
        batch_padding (bool): If False, inference will be performed in for-loop
            without batch padding.
    """

    def _load_model(
        self,
        path: str,
        model_kwargs: dict,
        peft_path: Optional[str] = None,
    ):
        """Load a causal LM into ``self.model`` and switch it to eval mode.

        Args:
            path (str): Model name or path.
            model_kwargs (dict): Keywords for ``from_pretrained``;
                ``torch_dtype`` defaults to ``torch.float16`` when absent.
            peft_path (str, optional): PEFT adapter to wrap the model with.
        """
        # Copy first: setdefault() must not modify the caller's dict (which
        # may be the shared default of ``__init__``).
        model_kwargs = dict(model_kwargs)
        model_kwargs.setdefault("torch_dtype", torch.float16)

        if "llama" in self.config.model_type.lower():
            self.model = LlamaForCausalLM.from_pretrained(
                path, **model_kwargs
            )
        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                path, **model_kwargs
            )

        # Grow the embedding matrix when the tokenizer holds more tokens
        # than the model has embedding rows.
        embedding_size = self.model.get_input_embeddings().weight.shape[0]
        if len(self.tokenizer) > embedding_size:
            self.model.resize_token_embeddings(len(self.tokenizer))

        if peft_path is not None:
            from peft import PeftModel

            self.model = PeftModel.from_pretrained(
                self.model, peft_path, is_trainable=False
            )
        self.model.eval()


@MODELS.register_module()
class GPTQCausalLM(HuggingFace):
    """Model wrapper that loads a quantized checkpoint through AutoGPTQ."""

    def _load_model(
        self, path: str, model_kwargs: dict, peft_path: str = None
    ):
        """Load the quantized model into ``self.model`` in eval mode.

        Args:
            path (str): Location of the quantized model; the quantize config
                is read from the same location.
            model_kwargs (dict): Extra keywords forwarded to
                ``AutoGPTQForCausalLM.from_quantized``.
            peft_path (str, optional): Not used by this loader.
        """
        from auto_gptq import AutoGPTQForCausalLM
        from auto_gptq import BaseQuantizeConfig

        quant_cfg = BaseQuantizeConfig.from_pretrained(path)
        loader_kwargs = {"quantize_config": quant_cfg, **model_kwargs}
        self.model = AutoGPTQForCausalLM.from_quantized(
            path, **loader_kwargs
        )

        self.model.eval()


@MODELS.register_module()
class ExllamaCausalLM(HuggingFace):
    """Model wrapper that runs inference through the ``ExllamaHF`` adapter."""

    def _load_model(
        self, path: str, model_kwargs: dict, peft_path: str = None
    ):
        """Load the ExLlama model into ``self.model`` in eval mode.

        Args:
            path (str): Directory holding the model config and weights.
            model_kwargs (dict): Not used by this loader.
            peft_path (str, optional): Not used by this loader.
        """
        self.model = ExllamaHF.from_pretrained(path)
        self.model.eval()

    def get_logits(self, inputs: List[str], mask_length=None):
        """Run the model on ``inputs``.

        Args:
            inputs (List[str]): A list of strings.
            mask_length (Optional[List[int]]): Forwarded to the model on the
                non-batched path so it can compute the loss itself.

        Returns:
            tuple: ``(logits, {"tokens": tokens, "loss": loss})`` where
            ``loss`` is whatever the model reported (None on the batched
            path, which passes no labels).
        """
        if self.batch_padding and len(inputs) > 1:
            # batch inference
            # NOTE(review): ExllamaHF.__call__ only reads input_ids[0], so
            # this path does not evaluate the whole batch — confirm whether
            # batch_padding is meant to be supported for this backend.
            tokens = self.tokenizer(
                inputs,
                padding=True,
                truncation=True,
                max_length=self.max_seq_len,
            )

            tokens = {
                k: torch.tensor(
                    np.array(tokens[k]), device=self.model.device
                )
                for k in tokens
                if k in ["input_ids", "attention_mask"]
            }
            outputs = self.model(**tokens)
        else:
            input_ids = self.tokenizer(
                inputs,
                padding=False,
                truncation=True,
                max_length=self.max_seq_len,
            )["input_ids"]
            input_ids = torch.tensor(
                input_ids, device=self.model.device
            )
            tokens = {"input_ids": input_ids}

            # Passing labels makes the model return per-sequence losses.
            outputs = self.model(
                input_ids,
                labels=input_ids,
                mask_length=mask_length,
                ignore_token=self.tokenizer.pad_token_id,
            )

        return outputs.logits, {"tokens": tokens, "loss": outputs.loss}

    def _get_ppl(
        self, inputs: List[str], mask_length: Optional[List[int]] = None
    ) -> List[float]:
        """Get perplexity scores given a list of inputs.

        Args:
            inputs (List[str]): A list of strings.
            mask_length (Optional[List[int]]): A list of mask lengths. If
                provided, the perplexity scores will be calculated with the
                first mask_length[i] tokens masked out. It's okay to skip
                its implementation if advanced features in PPLInferencer is
                not needed.

        Returns:
            List[float]: A list of perplexity scores (a numpy array with one
            entry per input).
        """
        outputs, extra = self.get_logits(inputs, mask_length)

        # Use the loss computed by the model when there is one. It is a
        # numpy array, so it must be compared with None: its truth value is
        # ambiguous for several elements and False for a loss of 0.0.
        model_loss = extra.get("loss")
        if model_loss is not None:
            return model_loss

        input_ids = extra["tokens"]["input_ids"]
        shift_logits = outputs[..., :-1, :].contiguous()
        shift_labels = input_ids[..., 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss(
            reduction="none", ignore_index=self.tokenizer.pad_token_id
        )

        loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        ).view(shift_labels.size())

        if mask_length is not None:
            # Keep only shifted positions >= mask_length[i] - 1 of each row.
            first_scored = (
                torch.as_tensor(mask_length, device=loss.device).view(-1, 1)
                - 1
            )
            positions = torch.arange(
                shift_labels.size(-1), device=loss.device
            )
            loss = loss * (positions >= first_scored).to(loss.dtype)

        # Normalise by the number of non-pad tokens, minus the masked prefix.
        lens = (
            (input_ids != self.tokenizer.pad_token_id)
            .sum(-1)
            .cpu()
            .numpy()
        )
        if mask_length is not None:
            # In-place on purpose: a size mismatch raises instead of
            # silently broadcasting.
            lens -= np.array(mask_length)
        ce_loss = loss.sum(-1).cpu().detach().numpy() / lens
        return ce_loss
